{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Char RNN 生成文本\n",
    "在循环神经网络的章节我们了解到其非常擅长处理序列问题，那么对于文本而言，其也相当于一个序列，因为每句话都是由单词或者汉子按照序列顺序组成的，所以我们也能够使用 RNN 对其进行处理，那么如何能够生成文本呢？其实原理非常简单，下面我们来讲一讲 Char RNN。\n",
    "\n",
    "### 训练过程\n",
    "前面我们介绍过 RNN 的输入和输出存在多种关系，比如一对多，多对多等等，不同的输入对应着不同的应用，比如多对多可以用来做机器翻译等等，今天我们要讲的 Char RNN 在训练网络的时候是一个相同长度的多对多类型，也就是输入一个序列，输出一个吸纳共同长度的序列。\n",
    "\n",
    "具体的网络训练过程如下\n",
    "\n",
    "<img src=https://ws1.sinaimg.cn/large/006tNc79gy1fob5kq3r8jj30mt09dq2r.jpg width=700>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到上面的网络流程中，输入是一个序列 \"床 前 明 月 光\"，输出也是一个序列 \"前 明 月 光 床\"。如果你仔细观察可以发现网络的每一步输出都是下一步的输入，这是不是某种巧合呢？\n",
    "\n",
    "并不是的，这就是 Char RNN 的设计思路，对于任意的一句话，比如 \"我喜欢小猫\"，我们可以将其拆分为 Char RNN 的训练集，输入就是 \"我 喜 欢 小 猫\"，这构成了长度为 5 的序列，网络的每一步输出就是 \"喜 欢 小 猫 我\"。当然对于一个序列，其最后一个字符后面并没有其他的字符，所以有多种方式选择，比如将序列的第一个字符作为其输出，也就是 \"光\" 的输出是 \"床\"，或者将其本身作为输出，也就是 \"光\" 的输出是 \"光\"。\n",
    "\n",
    "这样设计有什么好处呢？因为训练的过程是一个监督的训练的过程，所以并不能看出这样做的意义，在生成文本的过程我们就能够看出这样做到底有什么好处。\n",
    "\n",
    "### 生成文本\n",
    "我们直接讲解一下生成文本的过程，就能够直观的解释训练过程的原因。\n",
    "\n",
    "首先需要输入网络一段初始的序列进行预热，预热的过程并不需要实际的输出结果，只是为了生成拥有记忆效果的隐藏状态，并将隐藏状态保留下来，接着我们开始正式生成文本，不断地生成新的句子，这个过程是可以无限循环下去，或者到达我们的要求输出长度，具体可以看看下面的图示\n",
    "\n",
    "<img src=https://ws2.sinaimg.cn/large/006tNc79gy1fob5z06w1uj30qh09m0sl.jpg width=800>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过上面的例子可以看到，我们能够不断地将前面输出的文字重新输入到网络，不断循环递归，最后生成我们想要的长度的句子，是不是很简单呢？\n",
    "\n",
    "下面我们用 PyTorch 来具体实现"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "我们使用古诗来作为例子，读取这个数据，看看其长什么样子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:52.315656Z",
     "start_time": "2018-02-18T03:28:52.286844Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('./dataset/poetry.txt', 'r') as f:\n",
    "    poetry_corpus = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:52.331908Z",
     "start_time": "2018-02-18T03:28:52.317790Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'寒随穷律变，春逐鸟声开。\\n初风飘带柳，晚雪间花梅。\\n碧林青旧竹，绿沼翠新苔。\\n芝田初雁去，绮树巧莺来。\\n晚霞聊自怡，初晴弥可喜。\\n日晃百花色，风动千林翠。\\n池鱼跃不同，园鸟声还异。\\n寄言博通者，知予物'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "poetry_corpus[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:52.338277Z",
     "start_time": "2018-02-18T03:28:52.334069Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "942681"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 看看字符数\n",
    "len(poetry_corpus)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "为了可视化比较方便，我们将一些其他字符替换成空格"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:52.353185Z",
     "start_time": "2018-02-18T03:28:52.340405Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'寒随穷律变 春逐鸟声开  初风飘带柳 晚雪间花梅  碧林青旧竹 绿沼翠新苔  芝田初雁去 绮树巧莺来  晚霞聊自怡 初晴弥可喜  日晃百花色 风动千林翠  池鱼跃不同 园鸟声还异  寄言博通者 知予物'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "poetry_corpus = poetry_corpus.replace('\\n', ' ').replace('\\r', ' ').replace('，', ' ').replace('。', ' ')\n",
    "poetry_corpus[:100]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 文本数值表示\n",
    "对于每个文字，电脑并不能有效地识别，所以必须做一个转换，将文字转换成数字，对所有非重复的字符，可以从 0 开始建立索引\n",
    "\n",
    "同时为了节省内存的开销，可以词频比较低的字去掉"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:52.642640Z",
     "start_time": "2018-02-18T03:28:52.355357Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class TextConverter(object):\n",
    "    def __init__(self, text_path, max_vocab=5000):\n",
    "        \"\"\"建立一个字符索引转换器\n",
    "        \n",
    "        Args:\n",
    "            text_path: 文本位置\n",
    "            max_vocab: 最大的单词数量\n",
    "        \"\"\"\n",
    "        \n",
    "        with open(text_path, 'r') as f:\n",
    "            text = f.read()\n",
    "        text = text.replace('\\n', ' ').replace('\\r', ' ').replace('，', ' ').replace('。', ' ')\n",
    "        # 去掉重复的字符\n",
    "        vocab = set(text)\n",
    "\n",
    "        # 如果单词总数超过最大数值，去掉频率最低的\n",
    "        vocab_count = {}\n",
    "        \n",
    "        # 计算单词出现频率并排序\n",
    "        for word in vocab:\n",
    "            vocab_count[word] = 0\n",
    "        for word in text:\n",
    "            vocab_count[word] += 1\n",
    "        vocab_count_list = []\n",
    "        for word in vocab_count:\n",
    "            vocab_count_list.append((word, vocab_count[word]))\n",
    "        vocab_count_list.sort(key=lambda x: x[1], reverse=True)\n",
    "        \n",
    "        # 如果超过最大值，截取频率最低的字符\n",
    "        if len(vocab_count_list) > max_vocab:\n",
    "            vocab_count_list = vocab_count_list[:max_vocab]\n",
    "        vocab = [x[0] for x in vocab_count_list]\n",
    "        self.vocab = vocab\n",
    "\n",
    "        self.word_to_int_table = {c: i for i, c in enumerate(self.vocab)}\n",
    "        self.int_to_word_table = dict(enumerate(self.vocab))\n",
    "\n",
    "    @property\n",
    "    def vocab_size(self):\n",
    "        return len(self.vocab) + 1\n",
    "\n",
    "    def word_to_int(self, word):\n",
    "        if word in self.word_to_int_table:\n",
    "            return self.word_to_int_table[word]\n",
    "        else:\n",
    "            return len(self.vocab)\n",
    "\n",
    "    def int_to_word(self, index):\n",
    "        if index == len(self.vocab):\n",
    "            return '<unk>'\n",
    "        elif index < len(self.vocab):\n",
    "            return self.int_to_word_table[index]\n",
    "        else:\n",
    "            raise Exception('Unknown index!')\n",
    "\n",
    "    def text_to_arr(self, text):\n",
    "        arr = []\n",
    "        for word in text:\n",
    "            arr.append(self.word_to_int(word))\n",
    "        return np.array(arr)\n",
    "\n",
    "    def arr_to_text(self, arr):\n",
    "        words = []\n",
    "        for index in arr:\n",
    "            words.append(self.int_to_word(index))\n",
    "        return \"\".join(words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:53.016322Z",
     "start_time": "2018-02-18T03:28:52.645616Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "convert = TextConverter('./dataset/poetry.txt', max_vocab=10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以可视化一下数字表示的字符"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:53.025196Z",
     "start_time": "2018-02-18T03:28:53.018514Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "寒随穷律变 春逐鸟声开\n",
      "[ 40 166 358 935 565   0  10 367 108  63  78]\n"
     ]
    }
   ],
   "source": [
    "# 原始文本字符\n",
    "txt_char = poetry_corpus[:11]\n",
    "print(txt_char)\n",
    "\n",
    "# 转化成数字\n",
    "print(convert.text_to_arr(txt_char))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构造时序样本数据\n",
    "为了输入到循环神经网络中进行训练，我们需要构造时序样本的数据，因为前面我们知道，循环神经网络存在着长时依赖的问题，所以说我们不能将所有的文本作为一个序列一起输入到循环神经网络中，我们需要将整个文本分成很多很多个序列组成 batch 输入到网络中，只要我们定好每个序列的长度，那么序列个数也就被决定了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:53.036447Z",
     "start_time": "2018-02-18T03:28:53.027222Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47134\n"
     ]
    }
   ],
   "source": [
    "n_step = 20\n",
    "\n",
    "# 总的序列个数\n",
    "num_seq = int(len(poetry_corpus) / n_step)\n",
    "\n",
    "# 去掉最后不足一个序列长度的部分\n",
    "text = poetry_corpus[:num_seq*n_step]\n",
    "\n",
    "print(num_seq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接着我们将序列中所有的文字转换成数字表示，重新排列成 (num_seq x n_step) 的矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:53.258155Z",
     "start_time": "2018-02-18T03:28:53.038479Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:53.921749Z",
     "start_time": "2018-02-18T03:28:53.260507Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([47134, 20])\n",
      "\n",
      "  40\n",
      " 166\n",
      " 358\n",
      " 935\n",
      " 565\n",
      "   0\n",
      "  10\n",
      " 367\n",
      " 108\n",
      "  63\n",
      "  78\n",
      "   0\n",
      "   0\n",
      " 150\n",
      "   4\n",
      " 443\n",
      " 284\n",
      " 182\n",
      "   0\n",
      " 131\n",
      "[torch.LongTensor of size 20]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "arr = convert.text_to_arr(text)\n",
    "arr = arr.reshape((num_seq, -1))\n",
    "arr = torch.from_numpy(arr)\n",
    "\n",
    "print(arr.shape)\n",
    "print(arr[0, :])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "据此，我们可以构建 PyTorch 中的数据读取来训练网络，这里我们将最后一个字符的输出 label 定为输入的第一个字符，也就是\"床前明月光\"的输出是\"前明月光床\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:53.945768Z",
     "start_time": "2018-02-18T03:28:53.925488Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class TextDataset(object):\n",
    "    def __init__(self, arr):\n",
    "        self.arr = arr\n",
    "        \n",
    "    def __getitem__(self, item):\n",
    "        x = self.arr[item, :]\n",
    "        \n",
    "        # 构造 label\n",
    "        y = torch.zeros(x.shape)\n",
    "        # 将输入的第一个字符作为最后一个输入的 label\n",
    "        y[:-1], y[-1] = x[1:], x[0]\n",
    "        return x, y\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.arr.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:53.950296Z",
     "start_time": "2018-02-18T03:28:53.947697Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_set = TextDataset(arr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以取出其中一个数据集参看一下是否是我们描述的这样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:53.957705Z",
     "start_time": "2018-02-18T03:28:53.952232Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "寒随穷律变 春逐鸟声开  初风飘带柳 晚\n",
      "随穷律变 春逐鸟声开  初风飘带柳 晚寒\n"
     ]
    }
   ],
   "source": [
    "x, y = train_set[0]\n",
    "print(convert.arr_to_text(x.numpy()))\n",
    "print(convert.arr_to_text(y.numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 建立模型\n",
    "模型可以定义成非常简单的三层，第一层是词嵌入，第二层是 RNN 层，因为最后是一个分类问题，所以第三层是线性层，最后输出预测的字符。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:54.022455Z",
     "start_time": "2018-02-18T03:28:53.959687Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "\n",
    "use_gpu = True\n",
    "\n",
    "class CharRNN(nn.Module):\n",
    "    def __init__(self, num_classes, embed_dim, hidden_size, \n",
    "                 num_layers, dropout):\n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        self.word_to_vec = nn.Embedding(num_classes, embed_dim)\n",
    "        self.rnn = nn.GRU(embed_dim, hidden_size, num_layers, dropout)\n",
    "        self.project = nn.Linear(hidden_size, num_classes)\n",
    "\n",
    "    def forward(self, x, hs=None):\n",
    "        batch = x.shape[0]\n",
    "        if hs is None:\n",
    "            hs = Variable(\n",
    "                torch.zeros(self.num_layers, batch, self.hidden_size))\n",
    "            if use_gpu:\n",
    "                hs = hs.cuda()\n",
    "        word_embed = self.word_to_vec(x)  # (batch, len, embed)\n",
    "        word_embed = word_embed.permute(1, 0, 2)  # (len, batch, embed)\n",
    "        out, h0 = self.rnn(word_embed, hs)  # (len, batch, hidden)\n",
    "        le, mb, hd = out.shape\n",
    "        out = out.view(le * mb, hd)\n",
    "        out = self.project(out)\n",
    "        out = out.view(le, mb, -1)\n",
    "        out = out.permute(1, 0, 2).contiguous()  # (batch, len, hidden)\n",
    "        return out.view(-1, out.shape[2]), h0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型\n",
    "在训练模型的时候，我们知道这是一个分类问题，所以可以使用交叉熵作为 loss 函数，在语言模型中，我们通常使用一个新的指标来评估结果，这个指标叫做困惑度(perplexity)，可以简单地看作对交叉熵取指数，这样其范围就是 $[1, +\\infty]$，也是越小越好。\n",
    "\n",
    "另外，我们前面讲过 RNN 存在着梯度爆炸的问题，所以我们需要进行梯度裁剪，在 pytorch 中使用 `torch.nn.utils.clip_grad_norm` 就能够轻松实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:54.030508Z",
     "start_time": "2018-02-18T03:28:54.024511Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "batch_size = 128\n",
    "train_data = DataLoader(train_set, batch_size, True, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:28:59.955521Z",
     "start_time": "2018-02-18T03:28:54.032512Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from mxtorch.trainer import ScheduledOptim\n",
    "\n",
    "model = CharRNN(convert.vocab_size, 512, 512, 2, 0.5)\n",
    "if use_gpu:\n",
    "    model = model.cuda()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "basic_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "optimizer = ScheduledOptim(basic_optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:31:48.754799Z",
     "start_time": "2018-02-18T03:28:59.957657Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, perplexity is: 290.865, lr:1.0e-03\n",
      "epoch: 2, perplexity is: 190.468, lr:1.0e-03\n",
      "epoch: 3, perplexity is: 124.909, lr:1.0e-03\n",
      "epoch: 4, perplexity is: 88.715, lr:1.0e-03\n",
      "epoch: 5, perplexity is: 67.819, lr:1.0e-03\n",
      "epoch: 6, perplexity is: 53.798, lr:1.0e-03\n",
      "epoch: 7, perplexity is: 43.619, lr:1.0e-03\n",
      "epoch: 8, perplexity is: 36.032, lr:1.0e-03\n",
      "epoch: 9, perplexity is: 30.195, lr:1.0e-03\n",
      "epoch: 10, perplexity is: 25.569, lr:1.0e-03\n",
      "epoch: 11, perplexity is: 21.868, lr:1.0e-03\n",
      "epoch: 12, perplexity is: 18.918, lr:1.0e-03\n",
      "epoch: 13, perplexity is: 16.482, lr:1.0e-03\n",
      "epoch: 14, perplexity is: 14.505, lr:1.0e-03\n",
      "epoch: 15, perplexity is: 12.870, lr:1.0e-03\n",
      "epoch: 16, perplexity is: 11.489, lr:1.0e-03\n",
      "epoch: 17, perplexity is: 10.358, lr:1.0e-03\n",
      "epoch: 18, perplexity is: 9.416, lr:1.0e-03\n",
      "epoch: 19, perplexity is: 8.619, lr:1.0e-03\n",
      "epoch: 20, perplexity is: 7.905, lr:1.0e-03\n"
     ]
    }
   ],
   "source": [
    "epochs = 20\n",
    "for e in range(epochs):\n",
    "    train_loss = 0\n",
    "    for data in train_data:\n",
    "        x, y = data\n",
    "        y = y.long()\n",
    "        if use_gpu:\n",
    "            x = x.cuda()\n",
    "            y = y.cuda()\n",
    "        x, y = Variable(x), Variable(y)\n",
    "\n",
    "        # Forward.\n",
    "        score, _ = model(x)\n",
    "        loss = criterion(score, y.view(-1))\n",
    "\n",
    "        # Backward.\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        # Clip gradient.\n",
    "        nn.utils.clip_grad_norm(model.parameters(), 5)\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.data[0]\n",
    "    print('epoch: {}, perplexity is: {:.3f}, lr:{:.1e}'.format(e+1, np.exp(train_loss / len(train_data)), optimizer.lr))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到，训练完模型之后，我们能够到达 2.72 左右的困惑度，下面我们就可以开始生成文本了。\n",
    "\n",
    "### 生成文本\n",
    "生成文本的过程非常简单，前面已将讲过了，给定了开始的字符，然后不断向后生成字符，将生成的字符作为新的输入再传入网络。\n",
    "\n",
    "这里需要注意的是，为了增加更多的随机性，我们会在预测的概率最高的前五个里面依据他们的概率来进行随机选择。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:31:48.770181Z",
     "start_time": "2018-02-18T03:31:48.758123Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def pick_top_n(preds, top_n=5):\n",
    "    top_pred_prob, top_pred_label = torch.topk(preds, top_n, 1)\n",
    "    top_pred_prob /= torch.sum(top_pred_prob)\n",
    "    top_pred_prob = top_pred_prob.squeeze(0).cpu().numpy()\n",
    "    top_pred_label = top_pred_label.squeeze(0).cpu().numpy()\n",
    "    c = np.random.choice(top_pred_label, size=1, p=top_pred_prob)\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-02-18T03:31:48.860330Z",
     "start_time": "2018-02-18T03:31:48.772317Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generate text is: 天青色等烟雨 片帆天际波中象璧 不似到仙林何在 新春山月低心出 波透兔中\n"
     ]
    }
   ],
   "source": [
    "begin = '天青色等烟雨'\n",
    "text_len = 30\n",
    "\n",
    "model = model.eval()\n",
    "samples = [convert.word_to_int(c) for c in begin]\n",
    "input_txt = torch.LongTensor(samples)[None]\n",
    "if use_gpu:\n",
    "    input_txt = input_txt.cuda()\n",
    "input_txt = Variable(input_txt)\n",
    "_, init_state = model(input_txt)\n",
    "result = samples\n",
    "model_input = input_txt[:, -1][:, None]\n",
    "for i in range(text_len):\n",
    "    out, init_state = model(model_input, init_state)\n",
    "    pred = pick_top_n(out.data)\n",
    "    model_input = Variable(torch.LongTensor(pred))[None]\n",
    "    if use_gpu:\n",
    "        model_input = model_input.cuda()\n",
    "    result.append(pred[0])\n",
    "text = convert.arr_to_text(result)\n",
    "print('Generate text is: {}'.format(text))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后可以看到，生成的文本已经想一段段的话了"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
